//-*-C++-*-

/***************************************************************************
 *
 *   Copyright (C) 2024 by Will Gauvin
 *   Licensed under the Academic Free License version 2.1
 *
 ***************************************************************************/

#ifndef __dsp_RescaleCUDA_h
#define __dsp_RescaleCUDA_h

#include "dsp/Rescale.h"
#include "dsp/LaunchConfig.h"
#include "dsp/Scratch.h"

namespace CUDA
{
  class RescaleEngine : public dsp::Rescale::Engine
  {
  public:
    RescaleEngine(cudaStream_t stream = 0);
    ~RescaleEngine();

    /**
     * @brief initialise the the engine.
     *
     * @param input a pointer to the input timeseries
     * @param nsample the number of samples used to calculate the scale and offsets for.
     * @param exact an indicator of whether to have exact number of samples to calculate the scales and offsets.
     * @param constant_offset_scale an indicator of whether to only calculate the scale and offset one or each time
     *    the transform method is called.
     */
    void init(const dsp::TimeSeries *input, uint64_t nsample, bool exact, bool constant_offset_scale) override;

    /**
     * @brief perform the transformation of the input timeseries data and write out to the given output timeseries
     *
     * @param input a pointer to the input timeseries data.
     * @param output a pointer to the output timeseries data.
     */
    void transform(const dsp::TimeSeries *input, dsp::TimeSeries *output) override;

    /**
     * @brief get the spectrum of offsets for the given polarization
     *
     * @param ipol the polarisation index to get spectrum of offsets for.
     *
     * @returns the spectrum of offsets for the given polarization
     */
    const float *get_offset(unsigned ipol) const override;

    /**
     * @brief get the spectrum of scales for the given polarization
     *
     * @param ipol the polarisation index to get spectrum of scales for.
     *
     * @returns the spectrum of scales for the given polarization
     */
    const float *get_scale(unsigned ipol) const override;

    /**
     * @brief Get the spectrum of integrated input values for the given polarization
     *
     * Integration is performed across nsample time samples.
     *
     * @param ipol the polarisation index to get spectrum of integrated inputs values for.
     * @returns the spectrum of integrated input values for the given polarization.
     */
    const double *get_freq_total(unsigned ipol) const override;

    /**
     * @brief Get the spectrum in integrated input values squared for the given polarization
     *
     * Integration is performed across the nsample time samples.
     *
     * @param ipol the polarisation index to get spectrum of the totale input values squared for the current nsample.
     * @returns the spectrum in integrated input values squared for the given polarization.
     */
    const double *get_freq_squared_total(unsigned ipol) const override;

  private:
    //! The CUDA stream in which operations will be scheduled
    cudaStream_t stream;

    //! gpu configuration
    LaunchConfig gpu_config;

    //! flag that ensures an exact number of samples, interval_samples, is used to compute offsets and scales
    bool exact{false};

    //! the number of time samples the transformation will assess when computing statistics
    uint64_t nsample{0};

    //! counter for the number of processed time samples
    uint64_t isample{0};

    //! the number of polarisations used
    unsigned npol{0};

    //! the number of channels used
    unsigned nchan{0};

    //! the number of dimensions used
    unsigned ndim{0};

    //! flag to track whether the first integration has been computed
    bool first_integration{false};

    //! flag that holds the offset and scale constant, after the first calculation
    bool constant_offset_scale{false};

    // Host arrays

    //! sum of time samples, ordered by [pol][chan]
    double *h_freq_total;

    //! sum of the square of the time samples, ordered by [pol][chan]
    double *h_freq_totalsq;

    //! normalisation scale, ordered by [pol][chan]
    float *h_scale;

    //! normalisation offset, ordered by [pol][chan]
    float *h_offset;

    // Device arrays

    //! sum of time samples, ordered by [chan][pol]
    float *d_freq_total;

    //! sum of the square of the time samples, ordered by [chan][pol]
    float *d_freq_totalsq;

    //! normalisation scale, ordered by [chan][pol]
    float *d_scale;

    //! normalisation offset, ordered by [chan][pol]
    float *d_offset;

    //! size of each freq total arrays in bytes
    size_t freq_size{0};

    //! size of scale and offset arrays in bytes
    size_t data_size_bytes{0};

    //! scratch used for handling of transforming host array layouts (PF-ordered) vs device array layouts (FP-ordered)
    Reference::To<dsp::Scratch> scratch;
  };
}

#endif // __dsp_RescaleCUDA_h
